{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e509b5de",
   "metadata": {},
   "source": [
    "## Factor Graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aed2998e",
   "metadata": {},
   "source": [
    "author: Jacob Schreiber <br>\n",
    "contact: jmschreiber91@gmail.com\n",
    "    \n",
    "Factor graphs are a powerful and, in my opinion, underused probabilistic model. Potentially, one reason that they are underutilized is that they are not as conceptually intuitive as Bayesian networks or mixture models. Here, we will alleviate any confusion.\n",
    "\n",
    "Factor graphs are similar to Bayesian networks in that they consist of a set of probability distributions and a graph connecting them. However, unlike Bayesian networks, this graph is bipartate, and the set of probability distributions is not simply the variables in a data set; rather, the distributions are the union of the marginal distributions of each variable and the factor distributions, which are usually multivariate. In the graph, marginal distributions are on one side, factor distributions are on the other side, and the undirected edges only factor distributions with the marginal distributions of the variables that comprise it.\n",
    "\n",
    "One way of thinking about this is to start with a Bayesian network. If you convert all the conditional probability distributions into joint probability distributions and keep the univariate distributions as is, you now have your set of factor distributions. Then, for each variable in your network, you add in a marginal distribution. This gives you the set of nodes in your factor graph. Finally, you add an edge between each factor distribution and the marginal distributions corresponding to the variables that make up that joint distribution. When the factor is a univariate distribution, you just add a single edge between that factor and its marginal distribution.\n",
    "\n",
    "Once the factor graph is instantiated, inference involves an iterative message passing algorithm. In the first step, marginal distributions emit messages which are copies of themselves. Then, in the second step, factor distributions take in estimates of each marginal distribution, combine them with the joint probability parameters, and emit new estimates for each variable back to each marginal distribution. Finally, in the third step, the marginal distributions take in these estimates, average them together, and emit the new estimates back to each variable. The second and third steps are repeated until convergence.\n",
    "\n",
    "Note: although parameter fitting is implemented for factor graphs, structure learning is not yet supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2327088e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n",
      "torch      : 1.13.0\n",
      "pomegranate: 1.0.0\n",
      "\n",
      "Compiler    : GCC 11.2.0\n",
      "OS          : Linux\n",
      "Release     : 4.15.0-208-generic\n",
      "Machine     : x86_64\n",
      "Processor   : x86_64\n",
      "CPU cores   : 8\n",
      "Architecture: 64bit\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%pylab inline\n",
    "import seaborn; seaborn.set_style('whitegrid')\n",
    "import torch\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -m -n -p torch,pomegranate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "517ce156",
   "metadata": {},
   "source": [
    "### Initialization and Fitting\n",
    "\n",
    "A factor graph is comprised of two sets of distributions: the factor distributions, which are a mixture of joint probabilities and univariate probabilities, and the marginal distributions, which are one univariate probability distribution per variable in your data. Marginal distributions are usually set to the uniform distribution, whereas the factor distributions encode statistics about the underlying data. These two sets of distributions are connected using an undirected unweighted biparte graph such that factors can be connected to marginals but not to other factors. Each variable in your data must have a corresponding marginal distribution and appear in at least one distribution on the factor side. Each variable can occur in multiple factor distributions.\n",
    "\n",
    "Let's start off by implementing the simplest factor graph: two variables joined in a joint categorical distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2bd14a31",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pomegranate.distributions import Categorical\n",
    "from pomegranate.distributions import JointCategorical\n",
    "from pomegranate.factor_graph import FactorGraph\n",
    "\n",
    "m1 = Categorical([[0.5, 0.5]])\n",
    "m2 = Categorical([[0.5, 0.5]])\n",
    "\n",
    "f1 = JointCategorical([[0.1, 0.3], [0.2, 0.4]])\n",
    "\n",
    "model = FactorGraph([f1], [m1, m2], [(m1, f1), (m2, f1)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4e137fe",
   "metadata": {},
   "source": [
    "Note that the order that edges are added to the model denote the ordering of variables in the `JointCategorical` distribution, in that the first edge containing `f1` (just the first edge in this case) indicates that the parent of that edge is the first dimension in the factor tensor. \n",
    "\n",
    "Similarly, the ordering that marginal distributions are added to the model correspond to the ordering of the variables in your data. Specifically, `m1` covers the first column, `m2` covers the second column, and accordingly `f1` is made up of data from the first and second columns. The ordering of variables in a factor do not need to be sorted. If the edges were added as `[(m2, f1), (m1, f1)]`, a valid factor graph would be produced with the only difference being that the first dimension in the `f1` tensor would correspond to the second column data."
   ]
  },
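  {
   "cell_type": "markdown",
   "id": "9c3e51a0",
   "metadata": {},
   "source": [
    "As a rough illustration of the ordering rule above, the sketch below looks up a factor entry for one data row under both edge orderings. It is plain Python rather than pomegranate's internals, and the helper `factor_value` and its `axis_to_column` argument are hypothetical names used only here. Reversing the edge order has the same effect as transposing the factor table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c3e51a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Factor table with the same values as f1 above\n",
    "f1_table = [[0.1, 0.3],\n",
    "            [0.2, 0.4]]\n",
    "\n",
    "def factor_value(table, axis_to_column, row):\n",
    "    # axis_to_column[k] is the data column that axis k of the table refers to\n",
    "    i = row[axis_to_column[0]]\n",
    "    j = row[axis_to_column[1]]\n",
    "    return table[i][j]\n",
    "\n",
    "row = [0, 1]  # first column holds 0, second column holds 1\n",
    "\n",
    "# Edges [(m1, f1), (m2, f1)]: axis 0 -> column 0, axis 1 -> column 1\n",
    "in_order = factor_value(f1_table, [0, 1], row)        # reads f1_table[0][1]\n",
    "\n",
    "# Edges [(m2, f1), (m1, f1)]: axis 0 -> column 1, axis 1 -> column 0\n",
    "reversed_order = factor_value(f1_table, [1, 0], row)  # reads f1_table[1][0]\n",
    "\n",
    "print(in_order, reversed_order)"
   ]
  },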
  {
   "cell_type": "markdown",
   "id": "6d2bf514",
   "metadata": {},
   "source": [
    "If you are constructing the factor graph in a more programmatic way you might prefer to use the `add_edge`, `add_marginal`, and `add_factor` methods to build the graph slowly. These methods are used internally when these values are passed into the initialization but can also be called themselves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "aa9779e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = FactorGraph()\n",
    "model.add_factor(f1)\n",
    "\n",
    "model.add_marginal(m1)\n",
    "model.add_marginal(m2)\n",
    "\n",
    "model.add_edge(m1, f1)\n",
    "model.add_edge(m2, f1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0569120b",
   "metadata": {},
   "source": [
    "If you have data, you can then train the factor parameters in the same way that you can fit other parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dcf33396",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0],\n",
       "        [1, 0],\n",
       "        [1, 0],\n",
       "        [0, 0],\n",
       "        [1, 1],\n",
       "        [0, 1],\n",
       "        [1, 0],\n",
       "        [1, 1],\n",
       "        [0, 0],\n",
       "        [1, 1],\n",
       "        [0, 1]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.randint(2, size=(11, 2))\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a3955630",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.2727, 0.1818],\n",
       "        [0.2727, 0.2727]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X)\n",
    "model.factors[0].probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7116ab8a",
   "metadata": {},
   "source": [
    "Importantly, these updates DO NOT change the marginal distribution values. This is because the factor values encode everything that you can learn from previous data (the likelihood function), whereas the marginal distributions represent your prior probability across symbols for that variable. When doing inference when you know the symbol coming from some variable, you set the marginal distributions to be that value 100%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5dc9145a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.5000, 0.5000]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.marginals[0].probs"
   ]
  },
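  {
   "cell_type": "markdown",
   "id": "9c3e51a2",
   "metadata": {},
   "source": [
    "For this single `JointCategorical` factor, the fitted values shown above are consistent with simple normalized co-occurrence counts. The sketch below recomputes them in plain Python from the rows of `X` printed earlier (copied by hand into the hypothetical list `X_rows`). It illustrates the arithmetic only and is not pomegranate's implementation. Note that the uniform marginals play no part in the calculation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c3e51a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 11 rows of X shown above, copied by hand\n",
    "X_rows = [[0, 0], [1, 0], [1, 0], [0, 0], [1, 1], [0, 1],\n",
    "          [1, 0], [1, 1], [0, 0], [1, 1], [0, 1]]\n",
    "\n",
    "# Count how often each (first column, second column) pair occurs\n",
    "counts = [[0, 0], [0, 0]]\n",
    "for a, b in X_rows:\n",
    "    counts[a][b] += 1\n",
    "\n",
    "# Normalize by the number of examples to get joint probabilities\n",
    "total = len(X_rows)\n",
    "probs = [[c / total for c in count_row] for count_row in counts]\n",
    "\n",
    "print(counts)\n",
    "print([[round(p, 4) for p in prob_row] for prob_row in probs])"
   ]
  },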
  {
   "cell_type": "markdown",
   "id": "c26256c0",
   "metadata": {},
   "source": [
    "### Probabilities and Log Probabilities\n",
    "\n",
    "Like other methods, one can calculate the probability of each example given the model. Unlike other models, this probability is factorized across the factors and the marginal distributions to be $P(X) = \\prod\\limits_{i=0}^{f}  P(X_{p_i} | F_i) \\prod\\limits_{i=0}^{d} P(X_i | M_i)$ when there are f factors and d dimensions to your data. Basically, you are evaluating the probability of the data under the likelihood of the model (the product over the factors) and multiplying that by the probability of the data under the marginals (the product over the marginals). Without any prior information, the marginal probability is constant and can be ignored. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ecdf7a22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0682])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.probability(X[:1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ea3c4e1",
   "metadata": {},
   "source": [
    "Remember that the above value includes multiplication by marginal distributions, so will probably not directly match any of those in the factor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "533f2276",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2.6856, -2.6856, -2.6856, -2.6856, -2.6856, -3.0910, -2.6856, -2.6856,\n",
       "        -2.6856, -2.6856, -3.0910])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.log_probability(X)"
   ]
  },
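  {
   "cell_type": "markdown",
   "id": "9c3e51a4",
   "metadata": {},
   "source": [
    "To make the factorization concrete, the sketch below recomputes the values above by hand for the two-variable model: one factor term multiplied by one term per marginal. It assumes the fitted factor (counts over 11 examples) and uniform marginals, and the function name `joint_probability` is made up for this illustration. To four decimals, the results should agree with the outputs of `probability` and `log_probability` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c3e51a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Fitted factor values (normalized counts) and the untouched uniform marginals\n",
    "factor = [[3 / 11, 2 / 11],\n",
    "          [3 / 11, 3 / 11]]\n",
    "marginals = [[0.5, 0.5], [0.5, 0.5]]\n",
    "\n",
    "def joint_probability(row):\n",
    "    # Product over factors (only one here) ...\n",
    "    p = factor[row[0]][row[1]]\n",
    "    # ... times the product over the per-column marginals\n",
    "    for column, value in enumerate(row):\n",
    "        p *= marginals[column][value]\n",
    "    return p\n",
    "\n",
    "p00 = joint_probability([0, 0])\n",
    "p01 = joint_probability([0, 1])\n",
    "\n",
    "print(round(p00, 4), round(math.log(p00), 4))\n",
    "print(round(p01, 4), round(math.log(p01), 4))"
   ]
  },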
  {
   "cell_type": "markdown",
   "id": "a1f1bd27",
   "metadata": {},
   "source": [
    "### Predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "215dabac",
   "metadata": {},
   "source": [
    "Similarly to Bayesian networks, factor graphs can make predictions for missing values in data sets. In fact, Bayesian networks and Markov networks both frequently construct factor graphs in the backend to do the actual inference. These approaches use the sum-product algorithm, also called loopy belief propagation. The algorithm works essentially as follows:\n",
    "\n",
    "\n",
    "- Initialize messages TO each factor FROM each marginal that is a copy of the marginal distribution\n",
    "- For each factor, iterate over the marginal distributions it is connected to and calculate the factor's marginal distribution using all OTHER messages to it, ignoring the message from the marginal distribution of interest, and send its belief of what that distribution should be back to it\n",
    "- For each marginal, multiply all incoming messages together to get an estimate of what the new marginal value should be.\n",
    "\n",
    "After the messages converge the new marginal distributions are returned representing the probabilities of each distribution\n",
    "\n",
    "When one knows what some of the variables should be, such as when they have an incomplete matrix, you can set the initial marginal probability distributions (not the messages, the actual distributions) to be consistent with that (i.e., in the categorical case, observed a `2` means assigning 100% of the probability to that category instead of using uniform probabilities).\n",
    "\n",
    "Specifying an incomplete matrix is done by using a `torch.masked.MaskedTensor`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f73cac6e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_89475/916229149.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  X_torch = torch.tensor(X[:4])\n",
      "/home/jmschr/anaconda3/lib/python3.9/site-packages/torch/masked/maskedtensor/core.py:156: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.\n",
      "  warnings.warn((\"The PyTorch API of MaskedTensors is in prototype stage \"\n"
     ]
    }
   ],
   "source": [
    "X_torch = torch.tensor(X[:4])\n",
    "mask = torch.tensor([[True, False],\n",
    "                     [False, True],\n",
    "                     [True, True],\n",
    "                     [False, False]])\n",
    "\n",
    "X_masked = torch.masked.MaskedTensor(X_torch, mask=mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43c982e1",
   "metadata": {},
   "source": [
    "The mask indicates with a `True` value which indices are known. It does not matter what data value is taken when the mask is `False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "eb4bb44e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0],\n",
       "        [0, 0],\n",
       "        [1, 0],\n",
       "        [1, 0]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(X_masked)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff7591e",
   "metadata": {},
   "source": [
    "This is a somewhat discrete view of the result of the sum-product algorithm though because a full distribution is calculated. If you would like to get the entire predicted probabilities you can do that:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "948b92bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[1.0000, 0.0000],\n",
       "         [0.5000, 0.5000],\n",
       "         [0.0000, 1.0000],\n",
       "         [0.4545, 0.5455]]),\n",
       " tensor([[0.6000, 0.4000],\n",
       "         [1.0000, 0.0000],\n",
       "         [1.0000, 0.0000],\n",
       "         [0.5455, 0.4545]])]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba(X_masked)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd41a294",
   "metadata": {},
   "source": [
    "Note that the output here is a list with two dimensions. Each tensor in the list corresponds to a dimension in the underlying data and the tensors have shape `(n_examples, n_categories)` when using categorical data types. Basically, the first row in the first tensor includes the probabilities that element `X_masked[0, 0]` takes value 0 and 1. Because `X_masked[0, 0]` is an observed value, i.e., has a mask value of `True`, this means that the probability is clamped to 1 at the observed value.\n",
    "\n",
    "If we only want the log probabilities we can calculate those as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "86c4906f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[ 0.0000,    -inf],\n",
       "         [-0.6931, -0.6931],\n",
       "         [   -inf,  0.0000],\n",
       "         [-0.7885, -0.6061]]),\n",
       " tensor([[-0.5108, -0.9163],\n",
       "         [ 0.0000,    -inf],\n",
       "         [ 0.0000,    -inf],\n",
       "         [-0.6061, -0.7885]])]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_log_proba(X_masked)"
   ]
  },
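  {
   "cell_type": "markdown",
   "id": "9c3e51a6",
   "metadata": {},
   "source": [
    "To connect the message passing steps listed at the start of this section with the numbers returned by `predict_proba`, here is a minimal plain-Python sketch for this specific graph (one factor, two binary variables). It is not pomegranate's implementation: the names `prior`, `normalize`, and `posterior_marginals` are hypothetical, missing values are written as `None` instead of using a mask, and the factor is assumed to hold the fitted counts over 11 examples. Because this graph has a single factor and no cycles, one pass of messages is already exact, so no iteration is needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c3e51a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fitted factor values (normalized counts over the 11 examples)\n",
    "factor = [[3 / 11, 2 / 11],\n",
    "          [3 / 11, 3 / 11]]\n",
    "\n",
    "def normalize(values):\n",
    "    total = sum(values)\n",
    "    return [v / total for v in values]\n",
    "\n",
    "def prior(observed_value, n_categories=2):\n",
    "    # Uniform when the value is missing (None), one-hot when it is observed\n",
    "    if observed_value is None:\n",
    "        return [1.0 / n_categories] * n_categories\n",
    "    return [1.0 if k == observed_value else 0.0 for k in range(n_categories)]\n",
    "\n",
    "def posterior_marginals(row):\n",
    "    # Step 1: each marginal sends a copy of itself to the factor\n",
    "    p1, p2 = prior(row[0]), prior(row[1])\n",
    "    # Step 2: the factor sums out the OTHER variable using its message\n",
    "    to_1 = [sum(factor[a][b] * p2[b] for b in range(2)) for a in range(2)]\n",
    "    to_2 = [sum(factor[a][b] * p1[a] for a in range(2)) for b in range(2)]\n",
    "    # Step 3: each marginal multiplies its prior by the incoming message\n",
    "    belief_1 = normalize([p1[a] * to_1[a] for a in range(2)])\n",
    "    belief_2 = normalize([p2[b] * to_2[b] for b in range(2)])\n",
    "    return belief_1, belief_2\n",
    "\n",
    "# The four masked rows from above, with None marking the hidden entries\n",
    "rows = [[0, None], [None, 0], [1, 0], [None, None]]\n",
    "for row in rows:\n",
    "    belief_1, belief_2 = posterior_marginals(row)\n",
    "    print([round(v, 4) for v in belief_1], [round(v, 4) for v in belief_2])"
   ]
  },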
  {
   "cell_type": "markdown",
   "id": "2105983d",
   "metadata": {},
   "source": [
    "### Summarize\n",
    "\n",
    "Factor graphs can make use of the summarization API for training just like the other models. Basically, you can summarize one or more batches of data using the `summarize` method and then call the `from_summaries` method to update the parameters of the distribution. This will only update the parameters of the factors and leave the marginal distributions the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2a693544",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.summarize(X[:3])\n",
    "model.from_summaries()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea3b601e",
   "metadata": {},
   "source": [
    "Then, we can check the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "554ff220",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.3333, 0.0000],\n",
       "        [0.6667, 0.0000]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.factors[0].probs"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
